"""
Phase 11: Reviewer #2 Response — Trilemma Paper
=================================================
Part 1: EMU magnitude robustness (IQR/SD effects, winsorize, leave-one-out)
Part 2: Logit hardening (region FE, income-group FE)
Part 3: Projection sensitivity (±1SE coefficient bands on strain)
"""

import pandas as pd
import numpy as np
from pathlib import Path
import sys
import warnings
warnings.filterwarnings('ignore')

# Locate the project relative to this script: PROJECT_DIR is the parent of the
# directory containing this file, and ROOT_DIR is the directory above that.
PROJECT_DIR = Path(__file__).resolve().parent.parent
ROOT_DIR = PROJECT_DIR.parent
# Put <ROOT_DIR>/multilateral/src at the front of the import path; this has to
# happen before the `model` import on the next line.
sys.path.insert(0, str(ROOT_DIR / "multilateral" / "src"))
from model import PanelGLS

# DATA: directory holding the processed input panel for this project.
DATA = PROJECT_DIR / "data" / "processed"
# OUT: directory for the generated markdown tables (created at import time).
OUT = PROJECT_DIR / "output" / "tables"
OUT.mkdir(parents=True, exist_ok=True)
# MULTILATERAL_DATA: processed data of the sibling "multilateral" follow-up
# project (full_panel.csv is read from here in projection_sensitivity).
MULTILATERAL_DATA = ROOT_DIR / "multilateral" / "followup" / "data" / "processed"

# Euro adoption waves as (join year, ISO3 codes), in chronological order.
_ACCESSION_WAVES = (
    (1999, ('AUT', 'BEL', 'FIN', 'FRA', 'DEU', 'IRL',
            'ITA', 'LUX', 'NLD', 'PRT', 'ESP')),
    (2001, ('GRC',)),
    (2007, ('SVN',)),
    (2008, ('CYP', 'MLT')),
    (2009, ('SVK',)),
    (2011, ('EST',)),
    (2014, ('LVA',)),
    (2015, ('LTU',)),
)

# ISO3 code -> year the country joined the eurozone.
EUROZONE_JOIN = {
    iso3: join_year
    for join_year, members in _ACCESSION_WAVES
    for iso3 in members
}
# Set of all eurozone member codes covered by EUROZONE_JOIN.
EUROZONE_ISO3 = set(EUROZONE_JOIN)

# ISO3 codes of the OECD members used for the OECD-only subsample
# (listed alphabetically).
OECD = {
    'AUS', 'AUT', 'BEL', 'CAN', 'CHE', 'CHL', 'COL', 'CRI',
    'CZE', 'DEU', 'DNK', 'ESP', 'EST', 'FIN', 'FRA', 'GBR',
    'GRC', 'HUN', 'IRL', 'ISL', 'ISR', 'ITA', 'JPN', 'KOR',
    'LTU', 'LUX', 'LVA', 'MEX', 'NLD', 'NOR', 'NZL', 'POL',
    'PRT', 'SVK', 'SVN', 'SWE', 'TUR', 'USA',
}

# World Bank regions (simplified mapping): ISO3 code -> region abbreviation.
# Countries absent from this mapping get a missing region downstream.
REGION_MAP = {
    # East Asia & Pacific
    'AUS': 'EAP', 'BRN': 'EAP', 'CHN': 'EAP', 'FJI': 'EAP', 'HKG': 'EAP',
    'IDN': 'EAP', 'JPN': 'EAP', 'KHM': 'EAP', 'KOR': 'EAP', 'LAO': 'EAP',
    'MAC': 'EAP', 'MHL': 'EAP', 'MMR': 'EAP', 'MNG': 'EAP', 'MYS': 'EAP',
    'NZL': 'EAP', 'PHL': 'EAP', 'PLW': 'EAP', 'PNG': 'EAP', 'SGP': 'EAP',
    'SLB': 'EAP', 'THA': 'EAP', 'TLS': 'EAP', 'TON': 'EAP', 'TUV': 'EAP',
    'TWN': 'EAP', 'VNM': 'EAP', 'VUT': 'EAP', 'WSM': 'EAP',
    # Europe & Central Asia
    'ALB': 'ECA', 'AND': 'ECA', 'ARM': 'ECA', 'AUT': 'ECA', 'AZE': 'ECA',
    'BEL': 'ECA', 'BGR': 'ECA', 'BIH': 'ECA', 'BLR': 'ECA', 'CHE': 'ECA',
    'CYP': 'ECA', 'CZE': 'ECA', 'DEU': 'ECA', 'DNK': 'ECA', 'ESP': 'ECA',
    'EST': 'ECA', 'FIN': 'ECA', 'FRA': 'ECA', 'GBR': 'ECA', 'GEO': 'ECA',
    'GRC': 'ECA', 'HRV': 'ECA', 'HUN': 'ECA', 'IRL': 'ECA', 'ISL': 'ECA',
    'ITA': 'ECA', 'KAZ': 'ECA', 'KGZ': 'ECA', 'LTU': 'ECA', 'LUX': 'ECA',
    'LVA': 'ECA', 'MDA': 'ECA', 'MKD': 'ECA', 'MLT': 'ECA', 'MNE': 'ECA',
    'NLD': 'ECA', 'NOR': 'ECA', 'POL': 'ECA', 'PRT': 'ECA', 'ROU': 'ECA',
    'RUS': 'ECA', 'SRB': 'ECA', 'SVK': 'ECA', 'SVN': 'ECA', 'SWE': 'ECA',
    'TJK': 'ECA', 'TKM': 'ECA', 'TUR': 'ECA', 'UKR': 'ECA', 'UZB': 'ECA',
    'XKX': 'ECA',
    # Latin America & Caribbean
    'ARG': 'LAC', 'ATG': 'LAC', 'BHS': 'LAC', 'BLZ': 'LAC', 'BOL': 'LAC',
    'BRA': 'LAC', 'BRB': 'LAC', 'CHL': 'LAC', 'COL': 'LAC', 'CRI': 'LAC',
    'CUB': 'LAC', 'DMA': 'LAC', 'DOM': 'LAC', 'ECU': 'LAC', 'GRD': 'LAC',
    'GTM': 'LAC', 'GUY': 'LAC', 'HND': 'LAC', 'HTI': 'LAC', 'JAM': 'LAC',
    'KNA': 'LAC', 'LCA': 'LAC', 'MEX': 'LAC', 'NIC': 'LAC', 'PAN': 'LAC',
    'PER': 'LAC', 'PRY': 'LAC', 'SLV': 'LAC', 'SUR': 'LAC', 'TTO': 'LAC',
    'URY': 'LAC', 'VCT': 'LAC', 'VEN': 'LAC',
    # Middle East & North Africa
    'ARE': 'MENA', 'BHR': 'MENA', 'DJI': 'MENA', 'DZA': 'MENA', 'EGY': 'MENA',
    'IRN': 'MENA', 'IRQ': 'MENA', 'ISR': 'MENA', 'JOR': 'MENA', 'KWT': 'MENA',
    'LBN': 'MENA', 'LBY': 'MENA', 'MAR': 'MENA', 'OMN': 'MENA', 'PSE': 'MENA',
    'QAT': 'MENA', 'SAU': 'MENA', 'SYR': 'MENA', 'TUN': 'MENA', 'YEM': 'MENA',
    # North America
    'CAN': 'NAC', 'USA': 'NAC',
    # South Asia
    'AFG': 'SAS', 'BGD': 'SAS', 'BTN': 'SAS', 'IND': 'SAS', 'LKA': 'SAS',
    'MDV': 'SAS', 'NPL': 'SAS', 'PAK': 'SAS',
    # Sub-Saharan Africa
    # (Mali, 'MLI', belongs here; it was previously filed under MENA.)
    'AGO': 'SSA', 'BDI': 'SSA', 'BEN': 'SSA', 'BFA': 'SSA', 'BWA': 'SSA',
    'CAF': 'SSA', 'CIV': 'SSA', 'CMR': 'SSA', 'COD': 'SSA', 'COG': 'SSA',
    'COM': 'SSA', 'CPV': 'SSA', 'ERI': 'SSA', 'ETH': 'SSA', 'GAB': 'SSA',
    'GHA': 'SSA', 'GIN': 'SSA', 'GMB': 'SSA', 'GNB': 'SSA', 'GNQ': 'SSA',
    'KEN': 'SSA', 'LBR': 'SSA', 'LSO': 'SSA', 'MDG': 'SSA', 'MLI': 'SSA',
    'MOZ': 'SSA', 'MRT': 'SSA', 'MUS': 'SSA', 'MWI': 'SSA', 'NAM': 'SSA',
    'NER': 'SSA', 'NGA': 'SSA', 'RWA': 'SSA', 'SDN': 'SSA', 'SEN': 'SSA',
    'SLE': 'SSA', 'SOM': 'SSA', 'SSD': 'SSA', 'STP': 'SSA', 'SWZ': 'SSA',
    'SYC': 'SSA', 'TCD': 'SSA', 'TGO': 'SSA', 'TZA': 'SSA', 'UGA': 'SSA',
    'ZAF': 'SSA', 'ZMB': 'SSA', 'ZWE': 'SSA',
}

# Income groups are not hard-coded here: logit_hardening() derives them at
# run time from gdp_pc_ppp, using approximate World Bank 2024 thresholds.


def stars(p):
    """Return the significance marker for p-value ``p``.

    '***' for p < 0.01, '**' for p < 0.05, '*' for p < 0.1, otherwise ''.
    A NaN p-value fails every comparison and therefore yields ''.
    """
    for cutoff, marker in ((0.01, '***'), (0.05, '**'), (0.1, '*')):
        if p < cutoff:
            return marker
    return ''


def fmt(val, se, p):
    """Format an estimate for a table cell pair.

    Returns a 2-tuple: the value at four decimals followed by its
    significance stars, and the standard error at four decimals in
    parentheses.
    """
    coef_cell = f"{val:.4f}" + stars(p)
    se_cell = "(" + f"{se:.4f}" + ")"
    return coef_cell, se_cell


def get_ez_df(df):
    """Return the rows of ``df`` for eurozone members from their join year on.

    Rows are gathered country by country in EUROZONE_JOIN order and the
    result carries a fresh 0..n-1 index.
    """
    member_frames = [
        df[(df['iso3'] == code) & (df['year'] >= first_year)]
        for code, first_year in EUROZONE_JOIN.items()
    ]
    return pd.concat(member_frames, ignore_index=True)


def run_panel_gls(df, y_var, x_vars, label):
    """Fit PanelGLS of ``y_var`` on ``x_vars`` and collect the estimates.

    Rows with a missing value in the outcome, any regressor, 'iso3' or
    'year' are dropped first. Returns None (after printing a note) when
    fewer than 50 complete rows remain or when the fit raises; otherwise
    returns a dict with 'model', 'n_obs', 'n_countries', 'r_squared',
    '<var>_coef' / '<var>_se' / '<var>_p' for every regressor, and the
    fitted estimator under '_gls'.
    """
    sample = df[[y_var] + x_vars + ['iso3', 'year']].dropna()
    n_rows = len(sample)
    if n_rows < 50:
        print(f"  {label}: insufficient obs ({n_rows}), skipping")
        return None

    gls = PanelGLS()
    y = sample[y_var].values
    X = sample[x_vars].values
    try:
        gls.fit(y, X, sample['iso3'].values, sample['year'].values)
    except Exception as e:
        # Best effort: a failed specification is reported and skipped.
        print(f"  {label}: GLS failed ({e}), skipping")
        return None

    result = dict(
        model=label,
        n_obs=gls.n_obs,
        n_countries=gls.n_countries,
        r_squared=gls.r_squared,
    )
    print(f"\n  {label} (N={gls.n_obs}, R²={gls.r_squared:.4f})")
    for idx, name in enumerate(x_vars):
        coef, std_err, pval = gls.beta[idx], gls.se[idx], gls.pvalues[idx]
        sig = stars(pval)
        print(f"    {name:30s} {coef:8.4f} ({std_err:.4f}) {sig}")
        result.update({
            f'{name}_coef': coef,
            f'{name}_se': std_err,
            f'{name}_p': pval,
        })

    result['_gls'] = gls
    return result


def run_logit(df, y_var, x_vars, label, dummy_vars=None):
    """Fit a pooled logit by maximum likelihood and report marginal effects.

    Parameters
    ----------
    df : DataFrame containing ``y_var``, ``x_vars``, 'iso3' and, if given,
        ``dummy_vars``. Rows with a missing value in any of them are dropped.
    y_var : name of the 0/1 outcome column.
    x_vars : list of regressor column names.
    label : model name used in printed output and in the result dict.
    dummy_vars : optional list of extra indicator columns appended after
        ``x_vars`` (used as fixed effects).

    Returns
    -------
    None (after printing a note) when fewer than 50 complete rows remain,
    when either outcome class has fewer than 5 observations, or when
    estimation raises. Otherwise a dict with 'model', 'n_obs',
    'n_countries', 'r_squared' (1 - ll_model / ll_null) and, for every
    regressor and dummy:
      '<var>_coef' : marginal effect at the regressor means,
      '<var>_se'   : coefficient SE scaled by p(1-p) at the means,
      '<var>_p'    : two-sided normal p-value of the coefficient.
    """
    from scipy.optimize import minimize
    from scipy import stats as sp_stats

    dummy_vars = list(dummy_vars) if dummy_vars else []
    all_var_names = list(x_vars) + dummy_vars

    sub = df[[y_var] + all_var_names + ['iso3']].dropna()
    if len(sub) < 50:
        print(f"  {label}: insufficient obs ({len(sub)}), skipping")
        return None

    y = sub[y_var].values.astype(float)

    # Design matrix: intercept, then x_vars, then dummy_vars.
    X = np.column_stack([np.ones(len(sub)),
                         sub[all_var_names].values.astype(float)])
    n, k = X.shape

    if y.sum() < 5 or (1 - y).sum() < 5:
        print(f"  {label}: insufficient outcome variation, skipping")
        return None

    # Standardize every column except the intercept so the optimizer works
    # on comparable scales; constant columns keep a unit scale.
    x_means = X[:, 1:].mean(axis=0)
    x_stds = X[:, 1:].std(axis=0)
    x_stds[x_stds == 0] = 1
    X_std = X.copy()
    X_std[:, 1:] = (X[:, 1:] - x_means) / x_stds

    def neg_ll(beta):
        z = X_std @ beta
        z = np.clip(z, -30, 30)
        p = 1 / (1 + np.exp(-z))
        p = np.clip(p, 1e-12, 1 - 1e-12)
        return -np.sum(y * np.log(p) + (1 - y) * np.log(1 - p))

    def grad(beta):
        z = X_std @ beta
        z = np.clip(z, -30, 30)
        p = 1 / (1 + np.exp(-z))
        return -X_std.T @ (y - p)

    try:
        opt = minimize(neg_ll, np.zeros(k), jac=grad,
                       method='BFGS', options={'maxiter': 1000, 'gtol': 1e-6})
        if not opt.success:
            # Keep the estimates (best effort) but make non-convergence visible.
            print(f"  {label}: WARNING optimizer did not report convergence "
                  f"({opt.message})")

        # Map the standardized solution back to the original scale.
        beta_std = opt.x
        beta = np.zeros(k)
        beta[1:] = beta_std[1:] / x_stds
        beta[0] = beta_std[0] - np.sum(beta_std[1:] * x_means / x_stds)

        # Standard errors from the inverse of X' W X on the original scale.
        z = X @ beta
        z = np.clip(z, -30, 30)
        p = 1 / (1 + np.exp(-z))
        W = p * (1 - p)
        H = X.T @ (X * W[:, None])
        try:
            V = np.linalg.inv(H)
            se = np.sqrt(np.maximum(np.diag(V), 0))
        except np.linalg.LinAlgError:
            se = np.full(k, np.nan)

        t_stats = beta / se
        # Survival function keeps precision for very small p-values, where
        # 1 - cdf rounds to exactly zero.
        pvalues = 2 * sp_stats.norm.sf(np.abs(t_stats))

        # Pseudo-R² against the intercept-only likelihood.
        ll_model = -neg_ll(beta_std)
        p_bar = y.mean()
        ll_null = n * (p_bar * np.log(p_bar + 1e-12) +
                       (1 - p_bar) * np.log(1 - p_bar + 1e-12))
        pseudo_r2 = 1 - ll_model / ll_null if ll_null != 0 else 0

        # Marginal effects at the regressor means (same clipping as above).
        z_mean = np.clip(X.mean(axis=0) @ beta, -30, 30)
        p_mean = 1 / (1 + np.exp(-z_mean))
        mfx = beta[1:] * p_mean * (1 - p_mean)

    except Exception as e:
        print(f"  {label}: logit failed ({e}), skipping")
        return None

    res = {
        'model': label,
        'n_obs': n,
        'n_countries': sub['iso3'].nunique(),
        'r_squared': pseudo_r2,
    }

    print(f"\n  {label} (N={n}, Pseudo-R²={pseudo_r2:.4f}) [Logit]")
    for i, name in enumerate(all_var_names):
        sig = stars(pvalues[i + 1])
        print(f"    {name:30s} β={beta[i+1]:8.4f} (se={se[i+1]:.4f}) {sig}  "
              f"[MFX={mfx[i]:.5f}]")
        res[f'{name}_coef'] = mfx[i]
        res[f'{name}_se'] = se[i + 1] * p_mean * (1 - p_mean)
        res[f'{name}_p'] = pvalues[i + 1]

    return res


# ═══════════════════════════════════════════════════════════════════════
# PART 1: EMU MAGNITUDE ROBUSTNESS
# ═══════════════════════════════════════════════════════════════════════

def emu_magnitude_robustness(df):
    """Run the Part 1 EMU magnitude checks and write them to markdown.

    Steps: restrict ``df`` to eurozone members post-join, express Z_1..Z_3
    and ca_gdp as deviations from the EMU-year mean, then (a) print SD/IQR
    of the deviations, (b) fit the baseline deviation regression and derive
    1-SD and IQR effects of Z_1_dev, (c/d) refit with ca_dev winsorized at
    p5/p95 and p1/p99, (e) refit dropping one member at a time.

    Writes OUT / "phase11_emu_robustness.md" and returns the baseline
    result dict from run_panel_gls (None if the baseline was skipped).
    """
    print("\n" + "=" * 70)
    print("PART 1: EMU MAGNITUDE ROBUSTNESS")
    print("=" * 70)

    ez_df = get_ez_df(df)
    z_vars = ['Z_1', 'Z_2', 'Z_3']
    controls = ['fiscal_bal_gdp', 'nfa_gdp_lag', 'rgdp_growth']

    # Compute deviations from the EMU mean of the same year
    emu_means = ez_df.groupby('year')[z_vars + ['ca_gdp']].mean()
    emu_means.columns = [f'{c}_emu_mean' for c in emu_means.columns]
    ez = ez_df.merge(emu_means, on='year', how='left')
    for z in z_vars:
        ez[f'{z}_dev'] = ez[z] - ez[f'{z}_emu_mean']
    ez['ca_dev'] = ez['ca_gdp'] - ez['ca_gdp_emu_mean']

    z_dev_vars = [f'{z}_dev' for z in z_vars]

    # ── 1a. Report IQR and SD of Z1_dev and CA_dev ──
    print("\n  -- Descriptive Statistics --")
    sub = ez[z_dev_vars + ['ca_dev']].dropna()
    for v in z_dev_vars + ['ca_dev']:
        q25, q75 = sub[v].quantile([0.25, 0.75])
        sd = sub[v].std()
        print(f"    {v:15s}: SD={sd:.4f}, IQR=[{q25:.4f}, {q75:.4f}], "
              f"IQR_width={q75-q25:.4f}")

    z1_dev_sd = sub['Z_1_dev'].std()
    z1_dev_iqr = sub['Z_1_dev'].quantile(0.75) - sub['Z_1_dev'].quantile(0.25)

    # ── 1b. Baseline Z_dev regression ──
    baseline = run_panel_gls(ez, 'ca_dev', z_dev_vars + controls, 'Baseline')
    if baseline:
        z1_coef = baseline['Z_1_dev_coef']
        z1_se = baseline['Z_1_dev_se']
        sd_effect = z1_coef * z1_dev_sd
        iqr_effect = z1_coef * z1_dev_iqr
        print(f"\n  Z1_dev coefficient: {z1_coef:.2f}")
        print(f"  1-SD effect on CA_dev: {sd_effect:.2f} pp")
        print(f"  IQR effect on CA_dev: {iqr_effect:.2f} pp")

    # ── 1c. Winsorized CA_dev (p5/p95) ──
    print("\n  -- Winsorized CA_dev --")
    p5, p95 = ez['ca_dev'].quantile([0.05, 0.95])
    print(f"    Winsorizing CA_dev at [{p5:.2f}, {p95:.2f}]")
    ez['ca_dev_wins'] = ez['ca_dev'].clip(p5, p95)
    winsorized = run_panel_gls(ez, 'ca_dev_wins', z_dev_vars + controls,
                                'Winsorized CA')

    # ── 1d. Winsorized CA_dev (p1/p99 — less aggressive) ──
    p1, p99 = ez['ca_dev'].quantile([0.01, 0.99])
    print(f"    Also winsorizing at [{p1:.2f}, {p99:.2f}]")
    ez['ca_dev_wins99'] = ez['ca_dev'].clip(p1, p99)
    wins99 = run_panel_gls(ez, 'ca_dev_wins99', z_dev_vars + controls,
                            'Wins p1/p99')

    # ── 1e. Leave-one-out ──
    print("\n  -- Leave-One-Out --")
    loo_results = []
    for drop_iso in sorted(EUROZONE_ISO3):
        ez_loo = ez[ez['iso3'] != drop_iso].copy()
        r = run_panel_gls(ez_loo, 'ca_dev', z_dev_vars + controls,
                          f'Drop {drop_iso}')
        if r:
            loo_results.append({
                'dropped': drop_iso,
                'Z1_dev_coef': r['Z_1_dev_coef'],
                'Z1_dev_se': r['Z_1_dev_se'],
                'Z1_dev_p': r['Z_1_dev_p'],
                'n_obs': r['n_obs'],
                'n_countries': r['n_countries'],
                'r_squared': r['r_squared'],
            })

    # Fix the columns up front so an empty result list still yields a frame
    # with the expected schema instead of raising KeyError below.
    loo_df = pd.DataFrame(loo_results, columns=[
        'dropped', 'Z1_dev_coef', 'Z1_dev_se', 'Z1_dev_p',
        'n_obs', 'n_countries', 'r_squared',
    ])
    print("\n  Leave-One-Out Summary:")
    if loo_df.empty:
        print("    No leave-one-out regressions could be estimated")
    else:
        print(f"    Z1_dev coef range: [{loo_df['Z1_dev_coef'].min():.2f}, "
              f"{loo_df['Z1_dev_coef'].max():.2f}]")
        if baseline:
            print(f"    Baseline: {baseline['Z_1_dev_coef']:.2f}")
        print(f"    Most influential drop: "
              f"{loo_df.loc[loo_df['Z1_dev_coef'].abs().idxmin(), 'dropped']} "
              f"(smallest |coef|)")

    # ── Write output ──
    lines = ["# EMU Magnitude Robustness\n"]

    # Standardized effects table
    lines.append("## Standardized Effects\n")
    if baseline:
        lines.append("| Metric | Value |")
        lines.append("|:---|---:|")
        lines.append(f"| Z1_dev coefficient | {z1_coef:.2f} |")
        lines.append(f"| Z1_dev SD (within-EMU) | {z1_dev_sd:.4f} |")
        lines.append(f"| Z1_dev IQR width | {z1_dev_iqr:.4f} |")
        lines.append(f"| 1-SD effect on CA/GDP deviation | {sd_effect:.2f} pp |")
        lines.append(f"| IQR effect on CA/GDP deviation | {iqr_effect:.2f} pp |")
        lines.append(f"| 1-SE (coefficient) | {z1_se:.2f} |")

    # Winsorization table (only the specifications that were estimated)
    lines.append("\n## Winsorization Robustness\n")
    results_list = [r for r in (baseline, winsorized, wins99) if r]

    if results_list:
        lines.append("| Specification | Z1_dev | SE | p | N | R² |")
        lines.append("|:---|---:|---:|---:|---:|---:|")
        for r in results_list:
            sig = stars(r['Z_1_dev_p'])
            lines.append(f"| {r['model']} | {r['Z_1_dev_coef']:.2f}{sig} | "
                        f"({r['Z_1_dev_se']:.2f}) | {r['Z_1_dev_p']:.4f} | "
                        f"{r['n_obs']} | {r['r_squared']:.4f} |")

    # LOO table
    lines.append("\n## Leave-One-Out\n")
    lines.append("| Dropped Country | Z1_dev | SE | p | N |")
    lines.append("|:---|---:|---:|---:|---:|")
    for _, row in loo_df.iterrows():
        sig = stars(row['Z1_dev_p'])
        lines.append(f"| {row['dropped']} | {row['Z1_dev_coef']:.2f}{sig} | "
                    f"({row['Z1_dev_se']:.2f}) | {row['Z1_dev_p']:.4f} | "
                    f"{int(row['n_obs'])} |")
    if baseline:
        sig = stars(baseline['Z_1_dev_p'])
        lines.append(f"| **Full sample** | **{baseline['Z_1_dev_coef']:.2f}{sig}** | "
                    f"**({baseline['Z_1_dev_se']:.2f})** | "
                    f"**{baseline['Z_1_dev_p']:.4f}** | **{baseline['n_obs']}** |")

    lines.append("\n*Panel GLS with country and year fixed effects. "
                 "EMU members post-accession. Z deviations from EMU-year means.*")
    lines.append("*\\*p<0.1, \\*\\*p<0.05, \\*\\*\\*p<0.01*")

    path = OUT / "phase11_emu_robustness.md"
    # The table contains non-ASCII text (R²); do not depend on the locale.
    path.write_text('\n'.join(lines), encoding='utf-8')
    print(f"\n  Saved: {path}")

    return baseline


# ═══════════════════════════════════════════════════════════════════════
# PART 2: LOGIT HARDENING (REGION AND INCOME FE)
# ═══════════════════════════════════════════════════════════════════════

def _attach_dummies(frame, column, prefix):
    """Add drop-first indicator columns for ``frame[column]``; return their names.

    ``frame`` is modified in place. Categories of a categorical column that
    do not occur in ``frame`` are removed first, so no all-zero indicator is
    created and the omitted reference level is one that is actually observed.
    """
    values = frame[column]
    if hasattr(values, 'cat'):  # the accessor exists only for categorical data
        values = values.cat.remove_unused_categories()
    dummies = pd.get_dummies(values, prefix=prefix, drop_first=True)
    for col in dummies.columns:
        frame[col] = dummies[col].astype(float)
    return list(dummies.columns)


def logit_hardening(df):
    """Add region FE and income-group FE to the peg-vs-float logit.

    Works on a copy of ``df`` (the caller's frame is left untouched), keeps
    observations with regime_3cat in {1, 3} and defines the outcome as
    regime_3cat == 1. Up to five models are estimated with run_logit:
    pooled, region FE, income FE, region + income FE, and an OECD-only
    subsample; any model that run_logit skips is left out of the table.

    Writes OUT / "phase11_logit_robustness.md" and returns the list of
    result dicts for the models that were estimated.
    """
    print("\n" + "=" * 70)
    print("PART 2: LOGIT HARDENING — REGION AND INCOME FE")
    print("=" * 70)

    # Helper columns are added below; keep them off the caller's DataFrame.
    df = df.copy()

    # Create region labels
    df['region'] = df['iso3'].map(REGION_MAP)
    print(f"  Region coverage: {df['region'].notna().sum()} / {len(df)} obs")
    print(f"  Regions: {sorted(df['region'].dropna().unique())}")

    # Create income group from gdp_pc_ppp (time-varying)
    if 'gdp_pc_ppp' in df.columns:
        # WB thresholds (approx 2024): low <1145, lower-mid 1146-4515,
        # upper-mid 4516-14005, high >14005. Low and lower-middle are pooled.
        df['income_group'] = pd.cut(
            df['gdp_pc_ppp'],
            bins=[0, 4515, 14005, np.inf],
            labels=['low_mid', 'upper_mid', 'high']
        )
        print("  Income group distribution:")
        print(f"    {df['income_group'].value_counts().to_dict()}")
    else:
        df['income_group'] = None

    # Binary peg vs float sample
    binary_df = df[df['regime_3cat'].isin([1, 3])].copy()
    binary_df['is_peg_binary'] = (binary_df['regime_3cat'] == 1).astype(float)

    controls = ['fiscal_bal_gdp', 'nfa_gdp_lag', 'rgdp_growth', 'kaopen']
    x_demo = ['Z_1', 'Z_2', 'Z_3']
    base_vars = x_demo + controls

    results = []
    # Parallel to `results`: (has region FE, has income FE) for each model,
    # so the indicator rows stay correct when a model is skipped.
    fe_flags = []

    def keep(result, region_fe, income_fe):
        if result:
            results.append(result)
            fe_flags.append((region_fe, income_fe))

    has_income = binary_df['income_group'].notna().sum() > 100

    # M1: Baseline pooled logit (replicate)
    keep(run_logit(binary_df, 'is_peg_binary', base_vars, 'Pooled'),
         False, False)

    # M2: With region dummies
    region_sub = binary_df[binary_df['region'].notna()].copy()
    region_dummy_cols = _attach_dummies(region_sub, 'region', 'reg')
    if len(region_dummy_cols) > 0:
        keep(run_logit(region_sub, 'is_peg_binary', base_vars,
                       'Region FE', dummy_vars=region_dummy_cols),
             True, False)

    # M3: With income group dummies
    if has_income:
        income_sub = binary_df[binary_df['income_group'].notna()].copy()
        income_dummy_cols = _attach_dummies(income_sub, 'income_group', 'inc')
        if len(income_dummy_cols) > 0:
            keep(run_logit(income_sub, 'is_peg_binary', base_vars,
                           'Income FE', dummy_vars=income_dummy_cols),
                 False, True)

    # M4: Region + income
    if len(region_dummy_cols) > 0 and has_income:
        both_sub = binary_df[
            binary_df['region'].notna() &
            binary_df['income_group'].notna()
        ].copy()
        both_dummy_cols = (_attach_dummies(both_sub, 'region', 'reg') +
                           _attach_dummies(both_sub, 'income_group', 'inc'))
        keep(run_logit(both_sub, 'is_peg_binary', base_vars,
                       'Region+Income', dummy_vars=both_dummy_cols),
             True, True)

    # M5: OECD subsample (no region FE — near-perfect separation with region)
    oecd_binary = binary_df[binary_df['iso3'].isin(OECD)].copy()
    keep(run_logit(oecd_binary, 'is_peg_binary', base_vars, 'OECD only'),
         False, False)

    # Write output
    lines = ["# Logit Robustness: Region and Income Fixed Effects\n"]
    lines.append("Peg vs Float (intermediate regimes dropped). "
                 "Columns report marginal effects at means.\n")

    if results:
        # Only report demographic variables + controls, not the FE dummies
        separator = "|:---|" + "|".join(["---:" for _ in results]) + "|"
        model_labels = [r['model'] for r in results]
        lines.append("| Variable | " + " | ".join(model_labels) + " |")
        lines.append(separator)

        for var in base_vars:
            coef_row = f"| {var} |"
            se_row = "| |"
            for r in results:
                if f'{var}_coef' in r:
                    c, s = fmt(r[f'{var}_coef'], r[f'{var}_se'], r[f'{var}_p'])
                    coef_row += f" {c} |"
                    se_row += f" {s} |"
                else:
                    coef_row += " |"
                    se_row += " |"
            lines.append(coef_row)
            lines.append(se_row)

        lines.append(separator)
        for stat_name, stat_key in [('N', 'n_obs'), ('Pseudo-R²', 'r_squared'),
                                     ('Countries', 'n_countries')]:
            row = f"| {stat_name} |"
            for r in results:
                if stat_key == 'r_squared':
                    row += f" {r[stat_key]:.4f} |"
                else:
                    row += f" {r[stat_key]} |"
            lines.append(row)

        # FE indicator rows follow the models actually present in the table.
        lines.append("| Region FE | " +
                     " | ".join('Yes' if reg else 'No' for reg, _ in fe_flags) +
                     " |")
        lines.append("| Income FE | " +
                     " | ".join('Yes' if inc else 'No' for _, inc in fe_flags) +
                     " |")

    lines.append("\n*Pooled logit with Hessian-based standard errors. "
                 "Marginal effects at means reported.*")
    lines.append("*Region FE: World Bank regional dummies (EAP, ECA, LAC, "
                 "MENA, NAC, SAS, SSA). Income FE: low/lower-middle, "
                 "upper-middle, high (from GDP per capita PPP).*")
    lines.append("*\\*p<0.1, \\*\\*p<0.05, \\*\\*\\*p<0.01*")

    path = OUT / "phase11_logit_robustness.md"
    # The table contains non-ASCII text (Pseudo-R²); do not depend on the locale.
    path.write_text('\n'.join(lines), encoding='utf-8')
    print(f"\n  Saved: {path}")

    return results


# ═══════════════════════════════════════════════════════════════════════
# PART 3: PROJECTION SENSITIVITY (±1SE BANDS ON STRAIN)
# ═══════════════════════════════════════════════════════════════════════

def projection_sensitivity(df, baseline_result):
    """Compute projected regime strain with ±1SE bands on the Z1_dev coefficient.

    Parameters
    ----------
    df : accepted for a uniform call signature; not read by this function.
    baseline_result : result dict of the baseline deviation regression
        (needs 'Z_k_dev_coef' / 'Z_k_dev_se' for k = 1..3 and 'n_obs'),
        or None, in which case the function prints a note and returns.

    Reads MULTILATERAL_DATA / "full_panel.csv", keeps the first row per
    eurozone member for each of 2020, 2030, ..., 2060, expresses Z_1..Z_3 as
    deviations from the EMU mean of that year, and predicts the CA/GDP
    deviation. The band moves only the Z_1_dev coefficient by one standard
    error in each direction; "upper"/"lower" are the larger/smaller of the
    two resulting predictions. Writes
    OUT / "phase11_projection_sensitivity.md" and returns None.
    """
    print("\n" + "=" * 70)
    print("PART 3: PROJECTION SENSITIVITY — ±1SE STRAIN BANDS")
    print("=" * 70)

    if baseline_result is None:
        print("  No baseline result, skipping")
        return

    z_vars = ['Z_1', 'Z_2', 'Z_3']
    z_dev_vars = [f'{z}_dev' for z in z_vars]

    # Extract coefficients and SEs
    coefs = {v: baseline_result[f'{v}_coef'] for v in z_dev_vars}
    ses = {v: baseline_result[f'{v}_se'] for v in z_dev_vars}

    print("  Baseline coefficients:")
    for v in z_dev_vars:
        print(f"    {v}: {coefs[v]:.2f} ± {ses[v]:.2f}")

    # Load projections: one row per (member, decade), ordered by country then year
    fp = pd.read_csv(MULTILATERAL_DATA / "full_panel.csv")
    decades = [2020, 2030, 2040, 2050, 2060]
    wanted = fp['iso3'].isin(EUROZONE_ISO3) & fp['year'].isin(decades)
    proj_df = (fp.loc[wanted, ['iso3', 'year'] + z_vars]
                 .drop_duplicates(subset=['iso3', 'year'], keep='first')
                 .sort_values(['iso3', 'year'])
                 .reset_index(drop=True))

    if len(proj_df) == 0:
        print("  No projection data")
        return

    # Compute deviations from EMU mean per decade
    for z in z_vars:
        proj_df[f'{z}_dev'] = proj_df[z] - proj_df.groupby('year')[z].transform('mean')

    # Point prediction plus a band from shifting only the Z1_dev coefficient
    # by ±1SE (main driver); Z2_dev / Z3_dev stay at their point estimates.
    # max/min of (coef ± se) * z equals coef * z ± se * |z|.
    point = sum(coefs[v] * proj_df[v] for v in z_dev_vars)
    half_width = abs(ses['Z_1_dev']) * proj_df['Z_1_dev'].abs()
    strain_df = pd.DataFrame({
        'iso3': proj_df['iso3'],
        'year': proj_df['year'].astype(int),
        'ca_dev_point': point,
        'ca_dev_upper': point + half_width,
        'ca_dev_lower': point - half_width,
        'Z_1_dev': proj_df['Z_1_dev'],
    })

    # Write output — focus on 2040 for key countries
    lines = ["# Projection Sensitivity: ±1SE Coefficient Bands\n"]
    lines.append("Regime strain (predicted CA/GDP deviation from EMU mean) "
                 "using within-EMU Z deviation coefficients.\n")
    lines.append("Point estimate uses baseline coefficients. "
                 "Bands shift only the Z1_dev coefficient by ±1 standard error "
                 "(Z2_dev and Z3_dev held at point estimates); lower/upper are "
                 "the smaller/larger of the two shifted predictions.\n")

    # Table: country × decade with bands
    lines.append("## 2040 Strain with Uncertainty Bands\n")
    lines.append("| Country | Z1_dev | Point Est. | Lower band | Upper band |")
    lines.append("|:---|---:|---:|---:|---:|")

    s2040 = strain_df[strain_df['year'] == 2040].sort_values('ca_dev_point')
    for _, row in s2040.iterrows():
        lines.append(f"| {row['iso3']} | {row['Z_1_dev']:+.3f} | "
                    f"{row['ca_dev_point']:+.2f} | "
                    f"{row['ca_dev_lower']:+.2f} | "
                    f"{row['ca_dev_upper']:+.2f} |")

    # Full decade table (point estimates only — for comparison with original)
    lines.append("\n## Full Projection (Point Estimates)\n")
    pivot = strain_df.pivot(index='iso3', columns='year', values='ca_dev_point')
    pivot.columns = [str(int(c)) for c in pivot.columns]
    lines.append("| Country | " + " | ".join(pivot.columns) + " |")
    lines.append("|:---|" + "|".join(["---:" for _ in pivot.columns]) + "|")
    for iso3 in s2040['iso3']:
        if iso3 not in pivot.index:
            continue
        row_str = f"| {iso3} |"
        for col in pivot.columns:
            val = pivot.loc[iso3, col]
            if pd.notna(val):
                row_str += f" {val:+.2f} |"
            else:
                row_str += " |"
        lines.append(row_str)

    # Max deficit/surplus by decade
    lines.append("\n## Extremes by Decade\n")
    lines.append("| Decade | Max Deficit (Country) | Max Surplus (Country) | "
                 "Total Spread |")
    lines.append("|---:|:---|:---|---:|")
    for yr in decades:
        # Ignore countries without a prediction so idxmin/idxmax are defined.
        syr = strain_df[strain_df['year'] == yr].dropna(subset=['ca_dev_point'])
        if len(syr) == 0:
            continue
        min_row = syr.loc[syr['ca_dev_point'].idxmin()]
        max_row = syr.loc[syr['ca_dev_point'].idxmax()]
        spread = max_row['ca_dev_point'] - min_row['ca_dev_point']
        lines.append(f"| {yr} | {min_row['ca_dev_point']:+.2f} "
                    f"({min_row['iso3']}) | {max_row['ca_dev_point']:+.2f} "
                    f"({max_row['iso3']}) | {spread:.2f} |")

    lines.append("\n*Coefficients from within-EMU Z deviation regression "
                 f"(N={baseline_result['n_obs']}). "
                 "Z projections from UN WPP medium variant. "
                 "Bands represent coefficient uncertainty only, not "
                 "demographic projection uncertainty.*")
    lines.append("*\\*p<0.1, \\*\\*p<0.05, \\*\\*\\*p<0.01*")

    path = OUT / "phase11_projection_sensitivity.md"
    # The text contains ± and other non-ASCII characters; a locale-dependent
    # default encoding can fail on them, so write UTF-8 explicitly.
    path.write_text('\n'.join(lines), encoding='utf-8')
    print(f"\n  Saved: {path}")


# ═══════════════════════════════════════════════════════════════════════
# PART 4: EMU WITHOUT HUBS + POST-2010 ROBUSTNESS
# ═══════════════════════════════════════════════════════════════════════

# Eurozone members treated as financial centers / tax hubs in Part 4.
EMU_HUBS = {'LUX', 'IRL', 'MLT', 'CYP', 'NLD', 'BEL'}


def _coef_table_lines(results, var_names):
    """Render a markdown coefficient table with one column per model.

    Each variable gets a coefficient row (with stars) and a standard-error
    row; models lacking a variable get empty cells. N, R² and the number of
    countries follow below a second separator. Returns a list of lines.
    """
    separator = "|:---|" + "|".join(["---:" for _ in results]) + "|"
    table = ["| Variable | " + " | ".join(r['model'] for r in results) + " |",
             separator]
    for var in var_names:
        coef_row = f"| {var} |"
        se_row = "| |"
        for r in results:
            if f'{var}_coef' in r:
                c, s = fmt(r[f'{var}_coef'], r[f'{var}_se'], r[f'{var}_p'])
                coef_row += f" {c} |"
                se_row += f" {s} |"
            else:
                coef_row += " |"
                se_row += " |"
        table.append(coef_row)
        table.append(se_row)
    table.append(separator)
    for stat, key in [('N', 'n_obs'), ('R²', 'r_squared'), ('Countries', 'n_countries')]:
        row = f"| {stat} |"
        for r in results:
            row += f" {r[key]:.4f} |" if key == 'r_squared' else f" {r[key]} |"
        table.append(row)
    return table


def emu_hub_exclusion(df):
    """4-column EMU robustness: baseline, excl hubs, post-2010, both.

    For each of the four samples (all EMU members post-join, the same
    without EMU_HUBS, years >= 2010, and both restrictions) this fits
    (A) ca_gdp on Z levels plus controls and (B) the CA deviation on Z
    deviations plus controls, with deviations taken from the year means of
    that same sample. For (B) the SD and IQR of Z_1_dev and the implied
    1-SD / IQR effects are added to the result dict.

    Writes OUT / "phase11_emu_hub_exclusion.md" and returns the list of
    deviation-regression result dicts that were estimated.
    """
    print("\n" + "=" * 70)
    print("PART 4: EMU WITHOUT HUBS + POST-2010 ROBUSTNESS")
    print("=" * 70)

    z_vars = ['Z_1', 'Z_2', 'Z_3']
    z_dev_vars = [f'{z}_dev' for z in z_vars]
    controls = ['fiscal_bal_gdp', 'nfa_gdp_lag', 'rgdp_growth']

    # Build EMU samples
    ez_full = get_ez_df(df)
    ez_no_hubs = ez_full[~ez_full['iso3'].isin(EMU_HUBS)].copy()
    ez_post2010 = ez_full[ez_full['year'] >= 2010].copy()
    ez_no_hubs_post = ez_no_hubs[ez_no_hubs['year'] >= 2010].copy()

    samples = [
        ('Baseline EMU', ez_full),
        ('Excl Hubs', ez_no_hubs),
        ('Post-2010', ez_post2010),
        ('Excl Hubs + Post-2010', ez_no_hubs_post),
    ]

    for label, sdf in samples:
        print(f"\n  {label}: {len(sdf)} obs, {sdf['iso3'].nunique()} countries, "
              f"years {sdf['year'].min()}-{sdf['year'].max()}")

    # ── Z levels → CA/GDP ──
    print("\n  --- Z Levels → CA/GDP ---")
    level_results = []
    for label, sdf in samples:
        r = run_panel_gls(sdf, 'ca_gdp', z_vars + controls, label)
        if r:
            level_results.append(r)

    # ── Z deviations → CA deviation ──
    print("\n  --- Z Deviations → CA Deviation ---")
    dev_results = []
    for label, sdf in samples:
        # Compute within-sample EMU means
        emu_means = sdf.groupby('year')[z_vars + ['ca_gdp']].mean()
        emu_means.columns = [f'{c}_emu_mean' for c in emu_means.columns]
        sdf_dev = sdf.merge(emu_means, on='year', how='left')
        for z in z_vars:
            sdf_dev[f'{z}_dev'] = sdf_dev[z] - sdf_dev[f'{z}_emu_mean']
        sdf_dev['ca_dev'] = sdf_dev['ca_gdp'] - sdf_dev['ca_gdp_emu_mean']

        r = run_panel_gls(sdf_dev, 'ca_dev', z_dev_vars + controls,
                          f'{label} (dev)')
        if r:
            # Also compute standardized effects
            sub = sdf_dev[z_dev_vars + ['ca_dev']].dropna()
            z1_sd = sub['Z_1_dev'].std()
            z1_iqr = sub['Z_1_dev'].quantile(0.75) - sub['Z_1_dev'].quantile(0.25)
            r['z1_dev_sd'] = z1_sd
            r['z1_dev_iqr'] = z1_iqr
            r['sd_effect'] = r['Z_1_dev_coef'] * z1_sd
            r['iqr_effect'] = r['Z_1_dev_coef'] * z1_iqr
            print(f"    Z1_dev SD={z1_sd:.4f}, IQR={z1_iqr:.4f}")
            print(f"    1-SD effect: {r['sd_effect']:.2f} pp, "
                  f"IQR effect: {r['iqr_effect']:.2f} pp")
            dev_results.append(r)

    # ── Dispersion metrics (Path C) ──
    print("\n  --- Cross-Sectional CA Dispersion Explained ---")
    for label, sdf in samples:
        avg_disp = sdf.groupby('year')['ca_gdp'].std().mean()
        print(f"    {label}: avg annual CA/GDP cross-sectional SD = {avg_disp:.2f}")

    # ── Write output ──
    lines = ["# EMU Robustness: Hub Exclusion and Post-2010 Sample\n"]
    lines.append("Financial center / tax hub countries excluded: "
                 f"{', '.join(sorted(EMU_HUBS))}\n")

    # Table 1: Z levels → CA/GDP
    lines.append("## Panel A: Z Levels → CA/GDP\n")
    if level_results:
        lines.extend(_coef_table_lines(level_results, z_vars + controls))

    # Table 2: Z deviations → CA deviation (with standardized effects)
    lines.append("\n## Panel B: Z Deviations → CA Deviation\n")
    if dev_results:
        lines.extend(_coef_table_lines(dev_results, z_dev_vars + controls))

        # Standardized effects summary. SD and IQR are shown with four
        # decimals, matching how they are printed above.
        model_labels = [r['model'] for r in dev_results]
        lines.append("\n## Panel C: Standardized Effects\n")
        lines.append("| Metric | " + " | ".join(model_labels) + " |")
        lines.append("|:---|" + "|".join(["---:" for _ in dev_results]) + "|")
        for metric_name, metric_key, spec in [
            ('Z1_dev coefficient', 'Z_1_dev_coef', '.2f'),
            ('Z1_dev SE', 'Z_1_dev_se', '.2f'),
            ('Z1_dev SD', 'z1_dev_sd', '.4f'),
            ('Z1_dev IQR', 'z1_dev_iqr', '.4f'),
            ('1-SD effect (pp)', 'sd_effect', '.2f'),
            ('IQR effect (pp)', 'iqr_effect', '.2f'),
        ]:
            row = f"| {metric_name} |"
            for r in dev_results:
                if metric_key in r:
                    row += f" {r[metric_key]:{spec}} |"
                else:
                    row += " |"
            lines.append(row)

    lines.append("\n*Panel GLS with country and year fixed effects. "
                 "EMU members post-accession.*")
    lines.append(f"*Hub countries excluded: {', '.join(sorted(EMU_HUBS))} "
                 "(financial centers / tax hubs with volatile CA/GDP).*")
    lines.append("*\\*p<0.1, \\*\\*p<0.05, \\*\\*\\*p<0.01*")

    path = OUT / "phase11_emu_hub_exclusion.md"
    # Headings contain "→" and "R²"; a locale-dependent default encoding can
    # fail on them, so write UTF-8 explicitly.
    path.write_text('\n'.join(lines), encoding='utf-8')
    print(f"\n  Saved: {path}")

    return dev_results


# ═══════════════════════════════════════════════════════════════════════
# MAIN
# ═══════════════════════════════════════════════════════════════════════

def main():
    """Load the trilemma panel and run Parts 1-4 in order.

    The baseline result returned by Part 1 is handed to Part 3.
    """
    rule = "=" * 70

    print(rule)
    print("PHASE 11: REVIEWER #2 RESPONSE — TRILEMMA PAPER")
    print(rule)

    panel = pd.read_csv(DATA / "trilemma_panel.csv")
    n_countries = panel['iso3'].nunique()
    print(f"Panel: {len(panel)} obs, {n_countries} countries")

    # Part 1 (EMU magnitude robustness) supplies the baseline for Part 3.
    baseline_result = emu_magnitude_robustness(panel)
    # Part 2: logit hardening.
    logit_hardening(panel)
    # Part 3: projection sensitivity.
    projection_sensitivity(panel, baseline_result)
    # Part 4: EMU hub exclusion + post-2010.
    emu_hub_exclusion(panel)

    print("\n" + rule)
    print("PHASE 11 COMPLETE")
    print(rule)


if __name__ == '__main__':
    main()
